#!/usr/bin/env python3
# G18_v2 — Re-center Mini-Pair (common interior window + heavy averaging)
# Present-act (stdlib only). Control: deterministic DDA 1/r per shell (rate_num integer), rails OFF.
# Readouts: slope on log-bins (mid-60%); plateau CV on equal-Δr full bins (outer fraction).
# Key changes vs v1:
#   • Use ONE common interior radial window for both centers:
#       r_max_glob = min(edge_radius(A), edge_radius(B)) - outer_margin
#     so both centers are averaged over the exact same shells and bin edges.
#   • Heavier averaging: larger H and wider linear bins to suppress lattice oscillations.

import argparse, csv, hashlib, json, math, os, sys
from datetime import datetime, timezone
from typing import Dict, List, Tuple

# ---------- utils ----------
def utc_timestamp() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH-MM-SSZ`` (colon-free)."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.strftime("%Y-%m-%dT%H-%M-%SZ")

def ensure_dirs(root: str, subs: List[str]) -> None:
    """Create ``root/<sub>`` (with any missing parents) for every entry in *subs*.

    Directories that already exist are left untouched.
    """
    for sub in subs:
        target = os.path.join(root, sub)
        os.makedirs(target, exist_ok=True)

def write_text(path: str, txt: str) -> None:
    """Overwrite the file at *path* with *txt*, encoded as UTF-8."""
    with open(path, mode="w", encoding="utf-8") as handle:
        handle.write(txt)

def json_dump(path: str, obj: dict) -> None:
    """Serialise *obj* to *path* as UTF-8 JSON with sorted keys and 2-space indent."""
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, sort_keys=True, indent=2)

def sha256_file(path: str) -> str:
    """Return the hex SHA-256 digest of the file at *path*, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    chunk_size = 1 << 20
    with open(path, "rb") as handle:
        while True:
            block = handle.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()

def isqrt(n: int) -> int:
    """Exact integer square root ``floor(sqrt(n))``; raises ValueError for n < 0."""
    # math.isqrt already returns an exact int, so no extra conversion is needed.
    return math.isqrt(n)

# ---------- geometry ----------
def build_shell_counts(N:int, cx:int, cy:int) -> Dict[int,int]:
    """Histogram the cells of an N x N grid by integer radius around (cx, cy).

    The cell at offset (dx, dy) from the center belongs to shell
    ``r = isqrt(dx*dx + dy*dy)``, i.e. the floor of its Euclidean distance.
    Cells are scanned row by row (y outer, x inner), so dict keys appear in
    first-encountered order. The center is not required to lie in the grid.
    """
    counts: Dict[int, int] = {}
    # Offsets dy = y - cy and dx = x - cx for y, x in range(N), same scan order.
    for dy in range(-cy, N - cy):
        for dx in range(-cx, N - cx):
            shell = math.isqrt(dx * dx + dy * dy)
            counts[shell] = counts.get(shell, 0) + 1
    return counts

def edge_radius(cx:int, cy:int, N:int) -> int:
    """Largest shell radius around (cx, cy) that lies wholly inside the N x N grid.

    Every cell in shell r has isqrt(dx*dx + dy*dy) == r, hence |dx| <= r and
    |dy| <= r; the shell therefore fits when r does not exceed the distance
    from the center to the nearest grid edge.
    """
    last = N - 1
    left, right = cx, last - cx
    bottom, top = cy, last - cy
    return min(left, bottom, right, top)

# ---------- control: DDA 1/r per shell ----------
def simulate_dda_1_over_r(shell_counts: Dict[int,int], H:int, rate_num:int) -> Dict[int,int]:
    """Deterministic DDA control: count how often each shell fires over H ticks.

    Each shell r > 0 owns one integer accumulator. Every tick the accumulator
    gains ``rate_num``; once it reaches ``r`` the shell fires once and ``r``
    is subtracted. This gives at most one fire per tick. For
    ``0 <= rate_num <= r`` the total is ``floor(H * rate_num / r)``, a rate of
    ``rate_num / r`` per tick. Shell 0 is skipped (1/r is singular) and
    reports 0 fires.

    Returns a dict with the same keys, in the same order, as ``shell_counts``.
    The cell counts themselves are not used.
    """
    fired: Dict[int, int] = dict.fromkeys(shell_counts, 0)
    # Shells never interact, so each accumulator can be run to completion alone.
    for radius in shell_counts:
        if radius == 0:  # singular 1/r shell: excluded from control and readouts
            continue
        acc = 0
        count = 0
        for _tick in range(H):
            acc += rate_num
            if acc >= radius:
                count += 1
                acc -= radius
        fired[radius] = count
    return fired

# ---------- diagnostics ----------
def linreg_y_on_x(xs: List[float], ys: List[float]) -> Tuple[float,float,float]:
    """Ordinary least-squares fit of ``y = a + b*x``.

    Returns ``(b, a, r2)``: slope, intercept and coefficient of determination.
    All three are NaN when fewer than two points are given or all x values
    coincide. When ``ss_tot`` is not positive (constant ys) r2 is reported as
    1.0. The point count is taken from ``xs``; x/y pairs are formed by ``zip``.
    """
    nan = float("nan")
    count = len(xs)
    if count < 2:
        return nan, nan, nan
    x_mean = sum(xs) / count
    y_mean = sum(ys) / count
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    sxx = sum((x - x_mean) * (x - x_mean) for x in xs)
    if sxx == 0:
        return nan, nan, nan
    slope = sxy / sxx
    intercept = y_mean - slope * x_mean
    ss_tot = sum((y - y_mean) * (y - y_mean) for y in ys)
    residuals = (y - (intercept + slope * x) for x, y in zip(xs, ys))
    ss_res = sum(e * e for e in residuals)
    if ss_tot > 0:
        r2 = 1.0 - ss_res / ss_tot
    else:
        r2 = 1.0
    return slope, intercept, r2

def slope_panel(shell_counts: Dict[int,int], fires: Dict[int,int], H:int,
                r_min:int, r_max:int, n_log_bins:int, fit_mid_frac:float) -> Dict[str,float]:
    """Fit the log-log radial falloff of the per-shell firing rate.

    Shells ``r`` with ``r_min <= r <= r_max`` and ``r > 0`` are grouped into
    ``n_log_bins`` logarithmically spaced bins spanning [first shell, last
    shell]. Each non-empty bin contributes one point:

        X = log(r_last / r_rep)   (r_rep = geometric mean of the bin's first/last shell)
        Y = log(mean of fires[r] / H over the bin's shells)

    Because X is measured inward from the last shell, a rate falling as
    ``r**-p`` gives a fitted slope of ``+p``. Only the central
    ``fit_mid_frac`` of the points, trimmed equally from both ends, is passed
    to the least-squares fit.

    Returns ``{"slope", "r2", "bins"}``, where ``bins`` is the number of
    non-empty log bins. ``slope`` and ``r2`` are NaN (with ``bins == 0``) when
    ``n_log_bins < 1`` or there are fewer than ``max(n_log_bins, 6)`` shells.
    They are also NaN (with the actual bin count) when fewer than four bins
    are non-empty.
    """
    nan = float("nan")
    rs = [r for r in sorted(shell_counts) if r_min <= r <= r_max and r > 0]
    if n_log_bins < 1 or len(rs) < max(n_log_bins, 6):  # need enough shells to bin
        return {"slope": nan, "r2": nan, "bins": 0}
    r_first, r_last = rs[0], rs[-1]
    log_lo, log_hi = math.log(r_first), math.log(r_last)
    edges = [math.exp(log_lo + (log_hi - log_lo) * i / n_log_bins)
             for i in range(n_log_bins + 1)]
    # exp(log(r)) does not always round-trip (e.g. exp(log(10)) can give
    # 10.000000000000002). That would silently drop the first and/or last
    # shell of the window, so pin the outer edges to the exact shell radii.
    edges[0], edges[-1] = float(r_first), float(r_last)
    X: List[float] = []
    Y: List[float] = []
    for i in range(n_log_bins):
        lo, hi = edges[i], edges[i + 1]
        is_last = i == n_log_bins - 1
        # Half-open [lo, hi): a shell lying exactly on an interior edge is counted
        # in one bin only. The last bin is closed so r_last is kept.
        shells = [r for r in rs if lo <= r and (r < hi or (is_last and r <= hi))]
        if not shells:
            continue
        # mean per-shell firing rate (fires per tick) across this log-bin
        rates = [fires[r] / H for r in shells]
        r_rep = math.exp((math.log(shells[0]) + math.log(shells[-1])) / 2.0)
        X.append(math.log(r_last / r_rep + 1e-12))
        Y.append(math.log(sum(rates) / len(rates) + 1e-12))
    k = len(X)
    if k < 4:
        return {"slope": nan, "r2": nan, "bins": k}
    # trim m points from each end to keep the central fit_mid_frac of the bins
    m = int(round(k * (1.0 - fit_mid_frac) / 2.0))
    if k - 2 * m >= 2:
        useX, useY = X[m:k - m], Y[m:k - m]
    else:
        useX, useY = X, Y
    slope, _, r2 = linreg_y_on_x(useX, useY)
    return {"slope": slope, "r2": r2, "bins": k}

def plateau_panel(shell_counts: Dict[int,int], fires: Dict[int,int], H:int,
                  r_min:int, r_max:int, shells_per_bin:int, outer_frac:float) -> Dict[str,float]:
    """Measure how flat the per-annulus firing amplitude is in the outer window.

    Shells ``r_min .. r_max`` are cut into consecutive full bins of
    ``shells_per_bin`` shells. A trailing partial bin is dropped, so callers
    passing the same ``r_min``/``r_max`` get identical bin edges. A bin's
    amplitude is ``sum(shell_counts[r] * fires[r] / H)`` over its shells;
    shells missing from ``shell_counts`` contribute 0. The last
    ``max(1, round(nbins * outer_frac))`` bins are summarised by their mean
    (``amp_mean``) and coefficient of variation (``cv`` = population std / mean).

    Returns ``{"cv", "amp_mean", "nbins"}``:
      * ``cv``/``amp_mean`` are NaN with ``nbins == 0`` unless the window spans
        at least ``shells_per_bin + 2`` shells (``r_max > r_min + shells_per_bin``);
      * ``cv`` is inf and ``amp_mean`` is 0.0 when the outer mean is zero.

    Raises:
        ValueError: if ``shells_per_bin < 1``. A zero width would divide by
            zero, and a negative width would never terminate the binning loop.
    """
    width = shells_per_bin
    if width < 1:
        raise ValueError(f"shells_per_bin must be >= 1, got {shells_per_bin!r}")
    if r_max <= r_min + width:
        return {"cv": float("nan"), "amp_mean": float("nan"), "nbins": 0}
    # Last shell covered by full bins. The guard above ensures at least one full
    # bin fits, so `bins` below is never empty.
    stop = r_max - ((r_max - r_min + 1) % width)
    bins: List[float] = []
    for lo in range(r_min, stop - width + 2, width):
        # Explicit accumulation (not sum()) keeps the float summation order fixed.
        val = 0.0
        for rr in range(lo, lo + width):
            if rr in shell_counts:
                val += shell_counts[rr] * (fires.get(rr, 0) / H)
        bins.append(val)
    # outer fraction
    k = len(bins)
    take = max(1, int(round(k * outer_frac)))
    outer = bins[-take:]
    mu = sum(outer) / len(outer)
    if mu == 0.0:
        return {"cv": float("inf"), "amp_mean": 0.0, "nbins": k}
    s2 = sum((v - mu) * (v - mu) for v in outer) / len(outer)
    return {"cv": math.sqrt(s2) / mu, "amp_mean": mu, "nbins": k}

# ---------- per-center run using common window ----------
def run_center_common(N:int, cx:int, cy:int, H:int, rate_num:int,
                      slope_cfg:dict, plat_cfg:dict,
                      r_min_slope:int, r_min_plat:int, r_max_glob:int) -> Dict[str,float]:
    """Simulate one center and read out both panels on the shared window.

    Builds the shell histogram around (cx, cy) and runs the DDA control for H
    ticks. It then evaluates the slope panel on shells
    ``max(1, r_min_slope) .. r_max_glob`` and the plateau panel on
    ``max(1, r_min_plat) .. r_max_glob``. Bin settings are read from
    ``slope_cfg`` ("n_log_bins", "fit_mid_frac") and ``plat_cfg``
    ("shells_per_bin", "outer_frac").

    Returns a flat dict: slope, r2, log_bins, cv, amp_mean, lin_nbins.
    """
    counts = build_shell_counts(N, cx, cy)
    fired = simulate_dda_1_over_r(counts, H, rate_num)

    slope_kw = {
        "r_min": max(1, r_min_slope),
        "r_max": r_max_glob,
        "n_log_bins": int(slope_cfg["n_log_bins"]),
        "fit_mid_frac": float(slope_cfg["fit_mid_frac"]),
    }
    sp = slope_panel(counts, fired, H, **slope_kw)

    plat_kw = {
        "r_min": max(1, r_min_plat),
        "r_max": r_max_glob,
        "shells_per_bin": int(plat_cfg["shells_per_bin"]),
        "outer_frac": float(plat_cfg["outer_frac"]),
    }
    pp = plateau_panel(counts, fired, H, **plat_kw)

    result: Dict[str, float] = {}
    result["slope"] = sp["slope"]
    result["r2"] = sp["r2"]
    result["log_bins"] = sp["bins"]
    result["cv"] = pp["cv"]
    result["amp_mean"] = pp["amp_mean"]
    result["lin_nbins"] = pp["nbins"]
    return result

# ---------- main ----------
def _center_ok(metrics: Dict[str, float], slope_target: float, slope_tol: float,
               r2_min: float, cv_max: float) -> bool:
    """Return True when one center meets the single-center acceptance criteria.

    Requires a non-NaN slope within ``slope_tol`` of ``slope_target``,
    ``r2 >= r2_min`` and ``cv <= cv_max``. A NaN r2 or cv compares False and
    therefore fails.
    """
    if math.isnan(metrics["slope"]):
        return False
    return (abs(metrics["slope"] - slope_target) <= slope_tol
            and metrics["r2"] >= r2_min
            and metrics["cv"] <= cv_max)


def _abs_diff_or_inf(a: float, b: float) -> float:
    """Return ``|a - b|``, or ``inf`` when either value is NaN (so the pair check fails)."""
    if math.isnan(a) or math.isnan(b):
        return float("inf")
    return abs(a - b)


def main():
    """Run the G18_v2 re-centering pair test end to end.

    CLI: ``--manifest`` (input JSON) and ``--outdir`` (output root). In order:
      1. create the output tree and write a re-serialised (sorted-key) copy of
         the manifest to ``config/manifest_g18_v2.json``;
      2. record the run environment in ``logs/env.txt``;
      3. compute one interior window ``r_max_glob`` shared by centers A and B
         (RuntimeError if it is not positive);
      4. run both centers and write ``outputs/metrics/g18_pair_metrics.csv``;
      5. apply the per-center and pair criteria from ``manifest["acceptance"]``,
         write ``outputs/audits/g18_audit.json`` and
         ``outputs/run_info/result_line.txt``, and print the result line.

    NOTE: the audit JSON can contain non-standard ``NaN``/``Infinity`` tokens,
    because failed metrics are NaN/inf and json_dump keeps json's default
    ``allow_nan``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--outdir", required=True)
    args = ap.parse_args()

    root = os.path.abspath(args.outdir)
    ensure_dirs(root, ["config","outputs/metrics","outputs/audits","outputs/run_info","logs"])

    with open(args.manifest, "r", encoding="utf-8") as f:
        mani = json.load(f)
    # The audit hashes this re-serialised copy, not the input file itself.
    mani_out = os.path.join(root, "config", "manifest_g18_v2.json")
    json_dump(mani_out, mani)

    # env
    write_text(os.path.join(root,"logs","env.txt"),
               "\n".join([f"utc={utc_timestamp()}",
                          f"os={os.name}", f"cwd={os.getcwd()}",
                          f"python={sys.version.split()[0]}"]))

    N      = int(mani["grid"]["N"])
    H      = int(mani["H"])
    rate_n = int(mani["rate_num"])
    omarg  = int(mani["outer_margin"])
    cA     = mani["centers"]["A"];  cB = mani["centers"]["B"]
    slope_cfg = mani["slope"];  plat_cfg = mani["plateau"]
    ax, ay = int(cA["cx"]), int(cA["cy"])
    bx, by = int(cB["cx"]), int(cB["cy"])

    # Common interior r_max for both centers: the nearer edge minus the margin.
    r_max_glob = min(edge_radius(ax, ay, N), edge_radius(bx, by, N)) - omarg
    if r_max_glob <= 0:
        # r_max_glob shrinks as outer_margin grows, so the remedy is a SMALLER margin.
        raise RuntimeError(
            f"Invalid r_max_glob={r_max_glob}; decrease outer_margin, increase grid N, "
            "or move the centers away from the grid edges.")

    r_min_slope = int(slope_cfg.get("r_min", 4))
    r_min_plat  = int(plat_cfg["r_min"])

    # run both centers on the SAME r_max_glob
    A = run_center_common(N, ax, ay, H, rate_n, slope_cfg, plat_cfg,
                          r_min_slope, r_min_plat, r_max_glob)
    B = run_center_common(N, bx, by, H, rate_n, slope_cfg, plat_cfg,
                          r_min_slope, r_min_plat, r_max_glob)

    # metrics CSV: one row per center
    mpath = os.path.join(root, "outputs/metrics", "g18_pair_metrics.csv")
    with open(mpath, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["center","slope","r2","cv","amp_mean","log_bins","lin_nbins","r_max_glob"])
        for label, res in (("A", A), ("B", B)):
            w.writerow([label]
                       + [f"{res[key]:.6f}" for key in ("slope", "r2", "cv", "amp_mean")]
                       + [res["log_bins"], res["lin_nbins"], r_max_glob])

    # acceptance
    acc = mani["acceptance"]
    slope_target = float(acc["slope_target"])
    slope_tol    = float(acc["slope_tol_abs"])
    r2_min       = float(acc["r2_min"])
    cv_max       = float(acc["cv_max"])
    delta_slope_max = float(acc["delta_slope_max"])
    delta_cv_max    = float(acc["delta_cv_max"])
    amp_rel_tol     = float(acc["amp_rel_tol"])

    perA_ok = _center_ok(A, slope_target, slope_tol, r2_min, cv_max)
    perB_ok = _center_ok(B, slope_target, slope_tol, r2_min, cv_max)

    delta_slope = _abs_diff_or_inf(A["slope"], B["slope"])
    delta_cv    = _abs_diff_or_inf(A["cv"], B["cv"])
    # A non-positive (or NaN) amplitude on either side makes the ratio inf, failing amp_ok.
    amp_ratio   = (B["amp_mean"]/A["amp_mean"]) if (A["amp_mean"]>0 and B["amp_mean"]>0) else float("inf")

    inv_ok  = (delta_slope <= delta_slope_max) and (delta_cv <= delta_cv_max)
    amp_ok  = (abs(amp_ratio - 1.0) <= amp_rel_tol)

    passed = bool(perA_ok and perB_ok and inv_ok and amp_ok)

    audit = {
        "sim": "G18_recentering_v2",
        "gridN": N, "H": H, "rate_num": rate_n, "outer_margin": omarg,
        "r_max_glob": r_max_glob,
        "centerA": A, "centerB": B,
        "pair": {
            "delta_slope": delta_slope,
            "delta_cv": delta_cv,
            "amp_ratio_B_over_A": amp_ratio,
            "inv_ok": inv_ok, "amp_ok": amp_ok
        },
        "per_center_ok": {"A": perA_ok, "B": perB_ok},
        "accept": acc,
        "pass": passed,
        "manifest_hash": sha256_file(mani_out)
    }
    json_dump(os.path.join(root, "outputs", "audits", "g18_audit.json"), audit)

    result_line = ("G18_v2 PASS={p} slopeA={sa:.4f} slopeB={sb:.4f} Δslope={ds:.4f} "
                   "cvA={ca:.4f} cvB={cb:.4f} Δcv={dc:.4f} amp_ratio={ar:.3f} r_max_glob={rg}"
                   .format(p=passed, sa=A["slope"], sb=B["slope"], ds=delta_slope,
                           ca=A["cv"], cb=B["cv"], dc=delta_cv, ar=amp_ratio, rg=r_max_glob))
    write_text(os.path.join(root, "outputs", "run_info", "result_line.txt"), result_line)
    print(result_line)

# Script entry point: run main() only when this file is executed directly.
if __name__ == "__main__":
    main()
